# ------------------------------------------------------------------------------
# Trains one ensemble model per sample split on the validation set
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 06_fit_ensemble.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state so any random draws made later in the script are
# reproducible across runs.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(Matrix)
library(tidyverse)
library(glue)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert()
library(OneR) # bin()

# Directory this script reads its inputs from and writes fitted models to.
temp <- here::here("code", "08_train_split_model", "temp")
# Project helpers loaded as a module object (supplies e.g. `safe_left_join`).
u <- modules::use(here::here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Validation-set cohort table and model subscores; `ids` is the left-hand
# table of the join, so the result is built on its rows.
ids <- readRDS(file.path(temp, "split_val_cohort.rds"))
subscores <- readRDS(file.path(temp, "subscores_val_set.rds"))
train_df <- u$safe_left_join(ids, subscores)

# Subset Data ------------------------------------------------------------------
message("Subsetting data...")
# Keep rows with `test_010_day` TRUE and neither exclusion flag set.
# NOTE(review): the original condition listed `!ids$excl_flag_chronic` twice;
# the duplicate was dropped. If a different exclusion flag was meant in its
# place, add it here.
keep_pop <- ids$test_010_day & !ids$excl_flag_chronic & !ids$excl_flag_c_int

# The mask is computed from `ids` but applied to `train_df`, which is only
# valid if the join kept exactly one row per `ids` row, in order.
stopifnot(length(keep_pop) == nrow(train_df))

# which() treats NA flags as "not kept"; indexing a data.frame/tibble with a
# logical NA would instead produce an all-NA row.
train_df <- train_df[which(keep_pop), ]

# Fit Ensemble -----------------------------------------------------------------
message("Fitting ensembles...")
# For each sample split, fit a Gaussian GLM (identity link, i.e. linear
# regression) of the outcome on that split's subscore, using only the rows
# in that split, and save the fitted model to `temp`.
for (i in c(1, 2)) {
  message("Split ", i)
  split_var <- glue("sample_split{i}")
  # which() drops NA split indicators rather than producing NA rows.
  keep_split <- which(train_df[[split_var]])

  train_df_split <- train_df[keep_split, ]

  # NOTE(review): this covariate was previously listed twice. R's formula
  # handling collapses duplicate terms, so the model had a single predictor.
  # If a second, distinct subscore was intended, add it to this vector.
  covariates <- c(
    glue("p__gbm__stent_or_cabg_010_day__tested__{i}")
  )
  form <- reformulate(
    response = "stent_or_cabg_010_day", termlabels = covariates
  )

  # Fit on this split's rows only, not the full validation set.
  fit <- glm(
    form, data = train_df_split, family = "gaussian"
  )

  message("Saving model ", i, "...")
  save_fp <- file.path(
    temp, glue("ensemble__stent_or_cabg_010_day__tested__{i}.rds")
  )
  saveRDS(fit, save_fp)
}

message("Done.")
